{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings \n",
    "#Install a global filter that ignores every warning raised from here on.\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1500, 768])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class Atten(torch.nn.Module):\n",
    "    \"\"\"Multi-head self-attention built from four linear projections.\n",
    "\n",
    "    Args:\n",
    "        dim: Size of the input and output feature dimension (default 768).\n",
    "        heads: Number of attention heads; must divide ``dim`` (default 12).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``dim`` is not divisible by ``heads``.\n",
    "\n",
    "    The sequence length is read from the input on every call, so inputs of\n",
    "    any length are handled instead of only length 1500.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim=768, heads=12):\n",
    "        super().__init__()\n",
    "\n",
    "        if dim % heads != 0:\n",
    "            raise ValueError(\n",
    "                f'dim ({dim}) must be divisible by heads ({heads})')\n",
    "\n",
    "        self.dim = dim\n",
    "        self.heads = heads\n",
    "        self.head_dim = dim // heads\n",
    "        #Queries are scaled by 1/sqrt(head_dim): exactly 0.125 for head_dim 64.\n",
    "        self.scale = self.head_dim**-0.5\n",
    "\n",
    "        #The key projection is the only one created without a bias.\n",
    "        self.q = torch.nn.Linear(dim, dim, bias=True)\n",
    "        self.k = torch.nn.Linear(dim, dim, bias=False)\n",
    "        self.v = torch.nn.Linear(dim, dim, bias=True)\n",
    "        self.out = torch.nn.Linear(dim, dim, bias=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply self-attention to ``x`` of shape ``[batch, seq, dim]``.\n",
    "\n",
    "        Returns a tensor of shape ``[batch, seq, dim]``.\n",
    "        \"\"\"\n",
    "        seq = x.shape[-2]\n",
    "\n",
    "        q = self.q(x) * self.scale\n",
    "        k = self.k(x)\n",
    "        v = self.v(x)\n",
    "\n",
    "        #[b, seq, dim] -> [b, seq, heads, head_dim] -> [b, heads, seq, head_dim]\n",
    "        #-> [b * heads, seq, head_dim]\n",
    "        def split_heads(t):\n",
    "            t = t.reshape(-1, seq, self.heads, self.head_dim)\n",
    "            return t.transpose(1, 2).reshape(-1, seq, self.head_dim)\n",
    "\n",
    "        q, k, v = split_heads(q), split_heads(k), split_heads(v)\n",
    "\n",
    "        #[b * heads, seq, head_dim] * [b * heads, head_dim, seq] -> [b * heads, seq, seq]\n",
    "        #[b * heads, seq, seq] * [b * heads, seq, head_dim] -> [b * heads, seq, head_dim]\n",
    "        atten = q.bmm(k.transpose(1, 2)).softmax(dim=-1).bmm(v)\n",
    "\n",
    "        #[b * heads, seq, head_dim] -> [b, heads, seq, head_dim]\n",
    "        #-> [b, seq, heads, head_dim] -> [b, seq, dim]\n",
    "        atten = atten.reshape(-1, self.heads, seq, self.head_dim)\n",
    "        atten = atten.transpose(1, 2).reshape(-1, seq, self.dim)\n",
    "\n",
    "        return self.out(atten)\n",
    "\n",
    "\n",
    "#Smoke test on a random [2, 1500, 768] batch; the cell displays the output shape.\n",
    "Atten()(torch.randn(2, 1500, 768)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1500, 768])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Layer(torch.nn.Module):\n",
    "    \"\"\"Transformer block: normalised self-attention, then a GELU feed-forward.\n",
    "\n",
    "    Each of the two sub-blocks starts with a LayerNorm over the last\n",
    "    dimension (768) and is wrapped in a residual connection.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        #Attention sub-block: LayerNorm followed by self-attention.\n",
    "        attention_modules = [torch.nn.LayerNorm(768), Atten()]\n",
    "        self.s1 = torch.nn.Sequential(*attention_modules)\n",
    "\n",
    "        #Feed-forward sub-block: LayerNorm, 768 -> 3072, GELU, 3072 -> 768.\n",
    "        feed_forward_modules = [\n",
    "            torch.nn.LayerNorm(768),\n",
    "            torch.nn.Linear(768, 3072),\n",
    "            torch.nn.GELU(),\n",
    "            torch.nn.Linear(3072, 768),\n",
    "        ]\n",
    "        self.s2 = torch.nn.Sequential(*feed_forward_modules)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``x`` after the attention and feed-forward residual steps.\"\"\"\n",
    "        after_attention = self.s1(x) + x\n",
    "        after_feed_forward = self.s2(after_attention) + after_attention\n",
    "        return after_feed_forward\n",
    "\n",
    "\n",
    "#Smoke test on a random [2, 1500, 768] batch; the cell displays the output shape.\n",
    "Layer()(torch.randn(2, 1500, 768)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1500, 768])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Encoder(torch.nn.Module):\n",
    "    \"\"\"Convolutional front-end, position table and 12 transformer layers.\n",
    "\n",
    "    Takes an input with 80 channels and returns 768 features per position.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        #Two convolutions, each followed by GELU; the second one uses stride 2.\n",
    "        conv_in = torch.nn.Conv1d(80, 768, kernel_size=3, stride=1, padding=1)\n",
    "        conv_down = torch.nn.Conv1d(768,\n",
    "                                    768,\n",
    "                                    kernel_size=3,\n",
    "                                    stride=2,\n",
    "                                    padding=1)\n",
    "        self.s1 = torch.nn.Sequential(conv_in, torch.nn.GELU(), conv_down,\n",
    "                                      torch.nn.GELU())\n",
    "\n",
    "        #Table of 1500 position vectors; forward() adds the whole table at once.\n",
    "        self.embed = torch.nn.Embedding(1500, 768)\n",
    "\n",
    "        #12 transformer layers followed by a final LayerNorm.\n",
    "        self.s2 = torch.nn.Sequential(\n",
    "            *[Layer() for _ in range(12)],\n",
    "            torch.nn.LayerNorm(768),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Encode ``x``: convolve, add the position table, run the layers.\"\"\"\n",
    "        #Move the channel axis last so it lines up with the position table.\n",
    "        features = self.s1(x).permute(0, 2, 1)\n",
    "        features = features + self.embed.weight\n",
    "        return self.s2(features)\n",
    "\n",
    "\n",
    "#Smoke test on a random [2, 80, 3000] input; the cell displays the output shape.\n",
    "Encoder()(torch.randn(2, 80, 3000)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def load_encoder(pretrained):\n",
    "    \"\"\"Build a fresh ``Encoder`` and copy the weights of ``pretrained`` into it.\n",
    "\n",
    "    ``pretrained`` must expose ``conv1``, ``conv2``, ``embed_positions``,\n",
    "    ``layer_norm`` and a ``layers`` sequence whose first 12 entries are read.\n",
    "    Returns the populated ``Encoder``.\n",
    "    \"\"\"\n",
    "    encoder = Encoder()\n",
    "\n",
    "    def copy_weights(target, source):\n",
    "        #Load the source module's state dict into the target module.\n",
    "        target.load_state_dict(source.state_dict())\n",
    "\n",
    "    copy_weights(encoder.s1[0], pretrained.conv1)\n",
    "    copy_weights(encoder.s1[2], pretrained.conv2)\n",
    "    copy_weights(encoder.embed, pretrained.embed_positions)\n",
    "\n",
    "    for index in range(12):\n",
    "        ours = encoder.s2[index]\n",
    "        theirs = pretrained.layers[index]\n",
    "        attention = ours.s1[1]\n",
    "        self_attn = theirs.self_attn\n",
    "\n",
    "        #(target, source) module pairs: projections, norms, then feed-forward.\n",
    "        pairs = (\n",
    "            (attention.q, self_attn.q_proj),\n",
    "            (attention.k, self_attn.k_proj),\n",
    "            (attention.v, self_attn.v_proj),\n",
    "            (attention.out, self_attn.out_proj),\n",
    "            (ours.s1[0], theirs.self_attn_layer_norm),\n",
    "            (ours.s2[0], theirs.final_layer_norm),\n",
    "            (ours.s2[1], theirs.fc1),\n",
    "            (ours.s2[3], theirs.fc2),\n",
    "        )\n",
    "        for target, source in pairs:\n",
    "            copy_weights(target, source)\n",
    "\n",
    "    #Index 12 of the stack is the final LayerNorm after the 12 layers.\n",
    "    copy_weights(encoder.s2[12], pretrained.layer_norm)\n",
    "\n",
    "    return encoder\n",
    "\n",
    "\n",
    "from transformers import WhisperForConditionalGeneration\n",
    "\n",
    "#Keep only the encoder of the pretrained Whisper checkpoint as the reference.\n",
    "pretrained = WhisperForConditionalGeneration.from_pretrained(\n",
    "    'models/whisper-small').model.encoder\n",
    "encoder = load_encoder(pretrained)\n",
    "\n",
    "x = torch.randn(2, 80, 3000)\n",
    "\n",
    "#Inference-only comparison: no autograd graph is needed for either forward pass.\n",
    "with torch.no_grad():\n",
    "    ours = encoder(x)\n",
    "    reference = pretrained(x).last_hidden_state\n",
    "\n",
    "#Displays tensor(True) only when every element matches exactly.\n",
    "(ours == reference).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "simple",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
